{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 基于CharRNN的藏头诗生成\n",
    "\n",
    "由于RNN自身的结构特性和设计原理，其在处理文本内容时有着天然的优势。\n",
    "在这一个示例中，我们将会介绍在给定句首的情况下，如何使用字符集别的RNN网络来生成工整和韵的藏头诗。当然，除了诗歌之外，char-rnn还可以被用来生成各种蕴含时序结构的内容，包括一段文字、一段使用tex语法书写的文章、一段代码等，详细内容可以参考[这篇博客](http://karpathy.github.io/2015/05/21/rnn-effectiveness/)。\n",
    "\n",
    "### 1. 网络结构\n",
    "本示例中将使用的charRNN网络结构如下:\n",
    "![char-rnn](img/char-rnn.svg)\n",
    "它由两层RNN单元Stack得到，在每一个时间点，输入是一个字符，例如图中输入的第一个字符是“床”，通过两层RNN单元运算后得到一个向量，将这个向量经过一层Dense Connect和softMax得到一个在全字典上的输出概率，在训练的时候，我们期望的输出是“前”，这里直接用交叉熵来作为损失函数；接下来，我们将“前”作为下一个时间点的输入，得到新的字符。而在训练完毕后的生成中，为了结果的多样性，每次可以选择topK概率的字符作为输出。\n",
    "\n",
    "### 2. 数据准备\n",
    "在github上已经有许多完善的古诗词仓库，我们这里选用[chinese-poetry](https://github.com/chinese-poetry/chinese-poetry)中提供的诗词作为训练样本。\n",
    "\n",
    "仓库中提供的诗词均以繁体展示，且被存储为如下所示的json格式。\n",
    ">[{\n",
    "    \"strains\": [\n",
    "      \"仄仄平平仄，平平仄仄平。\", \n",
    "      \"仄平平仄仄，平仄仄平平。\"\n",
    "    ],    \n",
    "    \"author\": \"蓋嘉運\", \n",
    "    \"paragraphs\": [\n",
    "      \"聞道黃花戍，頻年不解兵。\", \n",
    "      \"可憐閨裏月，偏照漢家營。\"\n",
    "    ],    \n",
    "    \"title\": \"雜曲歌辭 伊州 歌第三\"\n",
    "  }, \n",
    "  {\n",
    "    \"strains\": [\n",
    "      \"平仄平平仄，平平仄仄平。\", \n",
    "      \"仄平平仄仄，平仄仄平平。\"\n",
    "    ],    \n",
    "    \"author\": \"蓋嘉運\", \n",
    "    \"paragraphs\": [\n",
    "      \"千里東歸客，無心憶舊遊。\", \n",
    "      \"挂帆游白水，高枕到青州。\"\n",
    "    ],    \n",
    "    \"title\": \"雜曲歌辭 伊州 歌第四\"\n",
    "  }, \n",
    "  {\n",
    "    \"strains\": [\n",
    "      \"仄仄平平仄，平平仄仄平。\", \n",
    "      \"平平平仄仄，仄仄仄平平。\"\n",
    "    ],    \n",
    "    \"author\": \"蓋嘉運\", \n",
    "    \"paragraphs\": [\n",
    "      \"桂殿江烏對，彫屏海燕重。\", \n",
    "      \"秖應多釀酒，醉罷樂高鐘。\"\n",
    "    ],    \n",
    "    \"title\": \"雜曲歌辭 伊州 歌第五\"\n",
    "  },\n",
    "  ....\n",
    "]\n",
    "\n",
    "我们首先定义提取出需要的五言诗的函数："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "def extract_poems(to_file,\n",
    "                  origin_poetry_dir,\n",
    "                  poetry_type=5):\n",
    "    with open(to_file, \"w\") as fw:\n",
    "        filenames = filter(lambda name: name.startswith(\"poet\"),\n",
    "                           os.listdir(origin_poetry_dir))\n",
    "        for filename in filenames:\n",
    "            input_file = origin_poetry_dir + \"/\" + filename\n",
    "            with open(input_file, encoding=\"utf-8\") as fi:\n",
    "                json_data = json.load(fi),\n",
    "                for poem in json_data[0]:\n",
    "                    paragraphs = poem[\"paragraphs\"]\n",
    "                    if len(paragraphs) == 2 and all(map(lambda line: len(line) == (poetry_type + 1) * 2, paragraphs)):\n",
    "                        for two_lines in paragraphs:\n",
    "                            fw.write(\" \".join(two_lines) + \" \\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在网络中字符都需要被embedding为一个向量，在这个示例中，我们使用Word2vector的embedding向量作为char-rnn网络embedding层的初始值"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\users\\dawei\\appdata\\local\\programs\\python\\python36\\lib\\site-packages\\gensim\\utils.py:1197: UserWarning: detected Windows; aliasing chunkize to chunkize_serial\n",
      "  warnings.warn(\"detected Windows; aliasing chunkize to chunkize_serial\")\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from gensim.models import Word2Vec\n",
    "\n",
    "def word2vec(input_file):\n",
    "    with open(input_file, encoding='utf-8') as fi:\n",
    "        lines = fi.readlines()\n",
    "    model = Word2Vec(sentences=lines, size=128, min_count=1)\n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "同时，我们定义一个dataset类来方便进行batch提取和index和word间的转换"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DataSet:\n",
    "    def __init__(self, lines: list, index2word: list):\n",
    "        self._index2word = index2word\n",
    "        self._word2index = {word: i for (i, word) in enumerate(index2word)}\n",
    "        self._index_in_epoch = 0\n",
    "        self._num_examples = len(lines)\n",
    "        self._epochs_completed = False\n",
    "\n",
    "        lines = map(lambda line: [self._word2index[word] for word in line.split(\" \")], lines)\n",
    "        self._data = np.array(list(lines))\n",
    "\n",
    "    def next_batch(self, batch_size):\n",
    "        start = self._index_in_epoch\n",
    "        self._index_in_epoch += batch_size\n",
    "        if self._index_in_epoch > self._num_examples:\n",
    "            # Finished epoch\n",
    "            self._epochs_completed += 1\n",
    "            # Shuffle the data\n",
    "            perm = np.arange(self._num_examples)\n",
    "            np.random.shuffle(perm)\n",
    "            self._data = self._data[perm]\n",
    "            # Start next epoch\n",
    "            start = 0\n",
    "            self._index_in_epoch = batch_size\n",
    "            assert batch_size <= self._num_examples\n",
    "        end = self._index_in_epoch\n",
    "        return self._data[start:end]\n",
    "\n",
    "    def label2word(self, label):\n",
    "        return self._index2word[label]\n",
    "\n",
    "    def word2label(self, word):\n",
    "        return self._word2index[word]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. 网络定义\n",
    "在我们的char-rnn网络中，定义的placeholder首先包含输入字符input_ids以及标签input_labels。同时，为了适应输入batch的变化（训练时通常用大于1的batch输入，而预测时输入batch一般为1），我们还需要定义一个batch_size的placeholder。考虑到在TensorFLow静态图的特点，在预测时，我们要取出上一步中的hidden state作为输入，因此还需要定义隐藏状态的placeholder。\n",
    "\n",
    "在这个char-rnn中我们使用每两句诗作为一个训练样本，这样既可以在一定程度上保持两句之间的关联性，又能避免利用整首诗时过长的依赖。\n",
    "\n",
    "一些参数和placeholder的定义如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:<tensorflow.python.ops.rnn_cell_impl.BasicLSTMCell object at 0x0000021EF468D208>: Using a concatenated state is slower and will soon be deprecated.  Use state_is_tuple=True.\n",
      "WARNING:tensorflow:<tensorflow.python.ops.rnn_cell_impl.BasicLSTMCell object at 0x0000021EFAB10C50>: Using a concatenated state is slower and will soon be deprecated.  Use state_is_tuple=True.\n",
      "WARNING:tensorflow:<tensorflow.python.ops.rnn_cell_impl.BasicLSTMCell object at 0x0000021EF9741E48>: Using a concatenated state is slower and will soon be deprecated.  Use state_is_tuple=True.\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.contrib.rnn import BasicLSTMCell, MultiRNNCell\n",
    "\n",
    "\n",
    "ORIGIN_POETRY_DIR = \"data/json\"\n",
    "EXTRACTED_POETRY_FILE = \"data/train_poetry.txt\"\n",
    "TRAIN_FILE = \"data/train_poetry_s.txt\"\n",
    "\n",
    "EMBEDDING_SIZE = 128\n",
    "HIDDEN_SIZE = 256\n",
    "BATCH_SIZE = 128\n",
    "STACK_LAYER = 2\n",
    "INPUT_LENGTH = 12\n",
    "\n",
    "\n",
    "input_ids = tf.placeholder(tf.int32, shape=[None, INPUT_LENGTH], name=\"input_id\")\n",
    "input_labels = tf.placeholder(tf.int32, shape=[None, INPUT_LENGTH], name=\"input_label\")\n",
    "input_batch_size = tf.placeholder(tf.int32, shape=())\n",
    "\n",
    "\n",
    "# 选用标准的LSTM作为基本RNN单元，\n",
    "stacked_cells = BasicLSTMCell(num_units=HIDDEN_SIZE, state_is_tuple=False)\n",
    "if STACK_LAYER > 1:\n",
    "    stacked_cells = MultiRNNCell(\n",
    "        [BasicLSTMCell(HIDDEN_SIZE, state_is_tuple=False) for _ in range(STACK_LAYER)],\n",
    "        state_is_tuple=False)\n",
    "initial_zero_states = stacked_cells.zero_state(input_batch_size, tf.float32)\n",
    "initial_states = tf.placeholder(tf.float32, initial_zero_states.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "定义inference过程："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def inference(input_ids, batch_size, initial_state, initial_embeddings,\n",
    "              stacked_cells, hidden_size=HIDDEN_SIZE, input_length=INPUT_LENGTH):\n",
    "\n",
    "    vocab_size, embedding_size = initial_embeddings.shape\n",
    "\n",
    "    # embedding layer\n",
    "    with tf.name_scope(\"embedding\"):\n",
    "        with tf.device(\"/cpu:0\"):\n",
    "            embeddings = tf.get_variable(name=\"embedding\",\n",
    "                                         initializer=tf.truncated_normal(\n",
    "                                             [vocab_size, embedding_size]),\n",
    "                                         dtype=tf.float32)\n",
    "            # 3D array [batch, time_stamp, embedding_feature]\n",
    "            input_tensor = tf.nn.embedding_lookup(embeddings, input_ids)\n",
    "\n",
    "    # rnn layer\n",
    "    with tf.variable_scope(\"rnn\") as vs:\n",
    "        # transpose to [time_stamp, batch, embedding_feature]\n",
    "        input_tensor = tf.transpose(input_tensor, perm=[1, 0, 2])\n",
    "\n",
    "        state = initial_state\n",
    "        states = []\n",
    "        outputs = []\n",
    "        for time in range(input_length):\n",
    "            input_this_time = input_tensor[time]\n",
    "            if time > 0:\n",
    "                vs.reuse_variables()\n",
    "            # output shape: [batch_size, embedding_size]\n",
    "            output, state = stacked_cells(input_this_time, state)\n",
    "            outputs.append(output)\n",
    "            states.append(state)\n",
    "    \n",
    "    # softmax layer\n",
    "    with tf.name_scope(\"softmax\"):\n",
    "        # [batch, embedding_size * time]\n",
    "        outputs = tf.concat(outputs, axis=1)\n",
    "        # [batch * time, embedding_size ]\n",
    "        outputs = tf.reshape(outputs, [-1, hidden_size])\n",
    "        w = tf.Variable(initial_value=tf.truncated_normal(shape=[hidden_size, vocab_size]), name=\"w\")\n",
    "        b = tf.Variable(initial_value=tf.zeros([vocab_size]), name=\"b\")\n",
    "        logits = tf.matmul(outputs, w) + b\n",
    "        logits = tf.reshape(logits, [batch_size, input_length, vocab_size])\n",
    "\n",
    "    return logits, states[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在预测时我们需要的是RNN单元在第一个时间点的输出，由于TensorFLow静态图的特点，要取出每一次运行后的第一个输出和状态，作为下一次运行的输入，因此预测的函数如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(head_words, sess, logits, first_out_state, input_pl, batch_size_pl, initial_states_pl,\n",
    "            zero_states, dataset, output_length=INPUT_LENGTH):\n",
    "    ret_list = [\"1 \", \"2 \", \"3 \", \"4 \"]\n",
    "    for (i, head_word) in enumerate(head_words):\n",
    "        if i % 2 != 0:\n",
    "            continue\n",
    "        state = zero_states\n",
    "        next_input = head_word\n",
    "        ret_list[i] += next_input\n",
    "        for j in range(output_length):\n",
    "            input_index = dataset.word2label(next_input)\n",
    "            input_batch = np.zeros([1, output_length], dtype=np.int32)\n",
    "            input_batch[0][0] = input_index\n",
    "            logits_val, state = sess.run([logits, first_out_state],\n",
    "                                         feed_dict={input_pl: input_batch,\n",
    "                                                    batch_size_pl: 1,\n",
    "                                                    initial_states_pl: state})\n",
    "\n",
    "            # 下半句时，需要手动指定\n",
    "            if j == output_length // 2 - 1:\n",
    "                next_input = head_words[i+1]\n",
    "                i += 1\n",
    "            else:\n",
    "                next_input = dataset.label2word(np.argmax(logits_val[0][0]))\n",
    "\n",
    "            if next_input == \"\\n\":\n",
    "                ret_list[i] += \"\\\\n\"\n",
    "            else:\n",
    "                ret_list[i] += next_input\n",
    "    return ret_list"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "提取五言诗数据中的五言诗，json文件较大，可以自行下载，本次使用已经提取好了数据，在data目录下。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "#extract_poems(to_file=EXTRACTED_POETRY_FILE, origin_poetry_dir=ORIGIN_POETRY_DIR, poetry_type=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "为了看起来方便，使用opencc将繁体字转换为简体，可以使用任何其他工具将其转换为简体，在data目录下，已经转换好了。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "#os.system(\"opencc -i {0} -o {1} -c t2s\".format(EXTRACTED_POETRY_FILE, TRAIN_FILE))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\users\\dawei\\appdata\\local\\programs\\python\\python36\\lib\\site-packages\\ipykernel_launcher.py:3: DeprecationWarning: Call to deprecated `syn0` (Attribute will be removed in 4.0.0, use self.wv.vectors instead).\n",
      "  This is separate from the ipykernel package so we can avoid doing imports until\n"
     ]
    }
   ],
   "source": [
    "w2v_model = word2vec(TRAIN_FILE)\n",
    "\n",
    "initial_embeddings = w2v_model.wv.syn0\n",
    "\n",
    "logits, first_state = inference(input_ids, input_batch_size, initial_states,\n",
    "                                               initial_embeddings, stacked_cells)\n",
    "loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=input_labels, logits=logits))\n",
    "\n",
    "train_op = tf.train.AdamOptimizer(0.001).minimize(loss)\n",
    "\n",
    "init_op = tf.global_variables_initializer()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "训练过程"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iter 0/100000 with loss: 8.87\n",
      "1 深酽序宠扭扭\n",
      "2 度眦彭彭摸噉噉\n",
      "3 学觔暇速哇猖\n",
      "4 习奉斿彷彷彷彷\n",
      "===========================\n",
      "iter 100/100000 with loss: 5.79\n",
      "1 深，不有有有\n",
      "2 度。\\n。\\n。。\n",
      "3 学不可，不有\n",
      "4 习有人。\\n。\\n\n",
      "===========================\n",
      "iter 200/100000 with loss: 5.35\n",
      "1 深无人不见，\n",
      "2 度不见如来。\\n\n",
      "3 学不知见，不\n",
      "4 习见如来。\\n。\n",
      "===========================\n",
      "iter 300/100000 with loss: 5.28\n",
      "1 深无人不知，\n",
      "2 度人不知人。\\n\n",
      "3 学无人不知，\n",
      "4 习有不知人。\\n\n",
      "===========================\n",
      "iter 400/100000 with loss: 5.28\n",
      "1 深无人不见，\n",
      "2 度人不见人。\\n\n",
      "3 学无人不见，\n",
      "4 习有不见人。\\n\n",
      "===========================\n",
      "iter 500/100000 with loss: 5.2\n",
      "1 深人不见，不\n",
      "2 度不见人。\\n。\n",
      "3 学无人不见，\n",
      "4 习有不知人。\\n\n",
      "===========================\n",
      "iter 600/100000 with loss: 5.07\n",
      "1 深无不见，不\n",
      "2 度不知时。\\n。\n",
      "3 学生无事事，\n",
      "4 习爲不知人。\\n\n",
      "===========================\n",
      "iter 700/100000 with loss: 5.16\n",
      "1 深知不见，不\n",
      "2 度未得时。\\n。\n",
      "3 学人不可见，\n",
      "4 习爲不可得。\\n\n",
      "===========================\n",
      "iter 800/100000 with loss: 4.83\n",
      "1 深山中有人，\n",
      "2 度月不知时。\\n\n",
      "3 学道人不见，\n",
      "4 习有不知人。\\n\n",
      "===========================\n",
      "iter 900/100000 with loss: 4.87\n",
      "1 深山中人间，\n",
      "2 度月满天下。\\n\n",
      "3 学道人不见，\n",
      "4 习与一声声。\\n\n",
      "===========================\n",
      "iter 1000/100000 with loss: 4.88\n",
      "1 深山中有客，\n",
      "2 度我自何之。\\n\n",
      "3 学道人不见，\n",
      "4 习见不知人。\\n\n",
      "===========================\n",
      "iter 1100/100000 with loss: 4.62\n",
      "1 深山不可见，\n",
      "2 度头不能来。\\n\n",
      "3 学道人不得，\n",
      "4 习爲不知音。\\n\n",
      "===========================\n",
      "iter 1200/100000 with loss: 4.71\n",
      "1 深山如何处，\n",
      "2 度君满山中。\\n\n",
      "3 学道如何处，\n",
      "4 习是一年春。\\n\n",
      "===========================\n",
      "iter 1300/100000 with loss: 4.67\n",
      "1 深山中人少，\n",
      "2 度月照人间。\\n\n",
      "3 学道人不见，\n",
      "4 习上有馀人。\\n\n",
      "===========================\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-10-ded00a939cdc>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[0;32m     22\u001b[0m                                             \u001b[0minput_labels\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mlabel_batch\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     23\u001b[0m                                             \u001b[0minput_batch_size\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mBATCH_SIZE\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 24\u001b[1;33m                                             initial_states: get_zero_states(BATCH_SIZE)})\n\u001b[0m\u001b[0;32m     25\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0miter\u001b[0m \u001b[1;33m%\u001b[0m \u001b[1;36m100\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     26\u001b[0m             \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"iter {0}/{1} with loss: {2:.3}\"\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0miter\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmax_iter\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mloss_val\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\users\\dawei\\appdata\\local\\programs\\python\\python36\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36mrun\u001b[1;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[0;32m    898\u001b[0m     \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    899\u001b[0m       result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[1;32m--> 900\u001b[1;33m                          run_metadata_ptr)\n\u001b[0m\u001b[0;32m    901\u001b[0m       \u001b[1;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    902\u001b[0m         \u001b[0mproto_data\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\users\\dawei\\appdata\\local\\programs\\python\\python36\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_run\u001b[1;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[0;32m   1133\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m \u001b[1;32mor\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mhandle\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0mfeed_dict_tensor\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1134\u001b[0m       results = self._do_run(handle, final_targets, final_fetches,\n\u001b[1;32m-> 1135\u001b[1;33m                              feed_dict_tensor, options, run_metadata)\n\u001b[0m\u001b[0;32m   1136\u001b[0m     \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1137\u001b[0m       \u001b[0mresults\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\users\\dawei\\appdata\\local\\programs\\python\\python36\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_do_run\u001b[1;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[0;32m   1314\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1315\u001b[0m       return self._do_call(_run_fn, feeds, fetches, targets, options,\n\u001b[1;32m-> 1316\u001b[1;33m                            run_metadata)\n\u001b[0m\u001b[0;32m   1317\u001b[0m     \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1318\u001b[0m       \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_do_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0m_prun_fn\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeeds\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfetches\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\users\\dawei\\appdata\\local\\programs\\python\\python36\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_do_call\u001b[1;34m(self, fn, *args)\u001b[0m\n\u001b[0;32m   1320\u001b[0m   \u001b[1;32mdef\u001b[0m \u001b[0m_do_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1321\u001b[0m     \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1322\u001b[1;33m       \u001b[1;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1323\u001b[0m     \u001b[1;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mOpError\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1324\u001b[0m       \u001b[0mmessage\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcompat\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0me\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmessage\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\users\\dawei\\appdata\\local\\programs\\python\\python36\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_run_fn\u001b[1;34m(feed_dict, fetch_list, target_list, options, run_metadata)\u001b[0m\n\u001b[0;32m   1305\u001b[0m       \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_extend_graph\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1306\u001b[0m       return self._call_tf_sessionrun(\n\u001b[1;32m-> 1307\u001b[1;33m           options, feed_dict, fetch_list, target_list, run_metadata)\n\u001b[0m\u001b[0;32m   1308\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1309\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0m_prun_fn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mhandle\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\users\\dawei\\appdata\\local\\programs\\python\\python36\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_call_tf_sessionrun\u001b[1;34m(self, options, feed_dict, fetch_list, target_list, run_metadata)\u001b[0m\n\u001b[0;32m   1407\u001b[0m       return tf_session.TF_SessionRun_wrapper(\n\u001b[0;32m   1408\u001b[0m           \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_session\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0moptions\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtarget_list\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1409\u001b[1;33m           run_metadata)\n\u001b[0m\u001b[0;32m   1410\u001b[0m     \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1411\u001b[0m       \u001b[1;32mwith\u001b[0m \u001b[0merrors\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mraise_exception_on_not_ok_status\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0mstatus\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ],
     "output_type": "error"
    }
   ],
   "source": [
    "# 载入数据\n",
    "with open(TRAIN_FILE, encoding='utf-8') as fi:\n",
    "    lines = fi.readlines()\n",
    "data = DataSet(lines, w2v_model.wv.index2word)\n",
    "\n",
    "def get_zero_states(batch_size):\n",
    "    return np.zeros(shape=[batch_size, initial_zero_states.shape[1]], dtype=np.float32)\n",
    "\n",
    "# 最大迭代次数\n",
    "max_iter = 100000\n",
    "\n",
    "config = tf.ConfigProto()\n",
    "config.gpu_options.per_process_gpu_memory_fraction = 0.4\n",
    "with tf.Session(graph=loss.graph, config=config) as sess:\n",
    "    sess.run(init_op)\n",
    "    for iter in range(max_iter):\n",
    "        input_batch = data.next_batch(BATCH_SIZE)\n",
    "        label_batch = input_batch[:, 1:]\n",
    "        input_batch = input_batch[:, :-1]\n",
    "        *_, loss_val = sess.run([train_op, logits, loss],\n",
    "                                feed_dict={ input_ids: input_batch,\n",
    "                                            input_labels: label_batch,\n",
    "                                            input_batch_size: BATCH_SIZE,\n",
    "                                            initial_states: get_zero_states(BATCH_SIZE)})\n",
    "        if iter % 100 == 0:\n",
    "            print(\"iter {0}/{1} with loss: {2:.3}\".format(iter, max_iter, loss_val))\n",
    "            head_words = \"深度学习\"\n",
    "            ret = predict(head_words, sess,\n",
    "                          logits, first_state,                          # output tensor of inference\n",
    "                          input_ids, input_batch_size, initial_states,  # place holder\n",
    "                          get_zero_states(1), data)\n",
    "            for line in ret:\n",
    "                print(line)\n",
    "            print(\"===========================\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
